package edu.cmu.minorthird.classify.experiments; import edu.cmu.minorthird.classify.*; import edu.cmu.minorthird.classify.algorithms.trees.DecisionTreeLearner; import edu.cmu.minorthird.util.ProgressCounter; import edu.cmu.minorthird.util.gui.*; import org.apache.log4j.Logger; /** * View result of some sort of train/test experiment. * * @author William Cohen */ public class CrossValidatedDataset implements Visible { static private Logger log = Logger.getLogger(CrossValidatedDataset.class); private ClassifiedDataset[] cds; private ClassifiedDataset[] trainCds; private Evaluation v; public CrossValidatedDataset(ClassifierLearner learner,Dataset d,Splitter<Example> splitter) { this(learner,d,splitter,false); } public CrossValidatedDataset(ClassifierLearner learner,Dataset d,Splitter<Example> splitter,boolean saveTrainPartitions) { Dataset.Split s = d.split(splitter); cds = new ClassifiedDataset[s.getNumPartitions()]; trainCds = saveTrainPartitions ? new ClassifiedDataset[s.getNumPartitions()] : null; v = new Evaluation(d.getSchema()); ProgressCounter pc = new ProgressCounter("train/test","fold",s.getNumPartitions()); log.info("Number of splits: "+s.getNumPartitions()); for (int k=0; k<s.getNumPartitions(); k++) { Dataset trainData = s.getTrain(k); Dataset testData = s.getTest(k); log.info("Split with "+splitter+": training on "+trainData.size()+" and testing on "+testData.size()); Classifier c = new DatasetClassifierTeacher(trainData).train(learner); DatasetIndex testIndex = new DatasetIndex(testData); cds[k] = new ClassifiedDataset(c, testData, testIndex); if (trainCds!=null) trainCds[k] = new ClassifiedDataset(c, trainData, testIndex); v.extend(cds[k].getClassifier(),testData,k); v.setProperty("classesInFold"+(k+1), "train: "+classDistributionString(trainData.getSchema(),new DatasetIndex(trainData)) +" test: "+classDistributionString(testData.getSchema(),testIndex)); log.info("Stored classified dataset"); pc.progress(); } pc.finished(); } private String classDistributionString(ExampleSchema schema, DatasetIndex index) { StringBuffer buf = new StringBuffer(""); java.text.DecimalFormat fmt = new java.text.DecimalFormat("#####"); for (int i=0; i<schema.getNumberOfClasses(); i++) { if (buf.length()>0) buf.append("; "); String label = schema.getClassName(i); buf.append(fmt.format(index.size(label)) + " " + label); } return buf.toString(); } @Override public Viewer toGUI() { ParallelViewer main = new ParallelViewer(); for (int i=0; i<cds.length; i++) { final int k = i; main.addSubView( "Test Partition "+(i+1), new TransformedViewer(cds[0].toGUI()) { static final long serialVersionUID=20080130L; @Override public Object transform(Object o) { //what is this for? - frank //CrossValidatedDataset cvd = (CrossValidatedDataset)o; return cds[k]; }}); } if (trainCds!=null) { for (int i=0; i<trainCds.length; i++) { final int k = i; main.addSubView( "Train Partition "+(i+1), new TransformedViewer(cds[0].toGUI()) { static final long serialVersionUID=20080130L; @Override public Object transform(Object o) { //what is this for? - frank //CrossValidatedDataset cvd = (CrossValidatedDataset)o; return trainCds[k]; }}); } } main.addSubView( "Overall Evaluation", new TransformedViewer(v.toGUI()) { static final long serialVersionUID=20080130L; @Override public Object transform(Object o) { CrossValidatedDataset cvd = (CrossValidatedDataset)o; return cvd.v; } }); main.setContent(this); return main; } public Evaluation getEvaluation() { return v; } public static void main(String[] args) { Dataset train = SampleDatasets.sampleData("toy",false); ClassifierLearner learner = new DecisionTreeLearner(); //ClassifierLearner learner = new NaiveBayes(); CrossValidatedDataset cd = new CrossValidatedDataset(learner,train,new CrossValSplitter<Example>(3),true); new ViewerFrame("CrossValidatedDataset", cd.toGUI()); } }